#! /usr/bin/env python3
###########################################################################
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
###########################################################################
import os
import os.path
import re
import subprocess
import sys
import time

# sysfs directory containing one entry per PCI device.
_dev_dir = '/sys/bus/pci/devices/'
# Loose device ID as typed by the user: an optional "bus:" prefix (itself
# optionally preceded by "domain:"), a device number, and an optional
# ".function" suffix.  Group 1 keeps its trailing ':' and group 4 keeps its
# leading '.', so callers must strip them before converting.
_dev_re = re.compile(r'(?:(\w+:)?(\w+):)?(\w+)(\.\w+)?\Z')
# Fully qualified name of the form domain:bus:device.function, used to pick
# out device entries when listing a port's sysfs directory.
_full_dev_re = re.compile(r'\w+:\w+:\w+\.\w\Z')

def usage(f):
    """Write a short usage message to the file-like object *f*.

    The list of valid commands is built from the module-level names that
    start with ``cmd_``, so adding a new ``cmd_*`` function automatically
    advertises it here.
    """
    # "%0" is not a Python format directive and was printed literally;
    # substitute the real program name instead.
    f.write("Usage: %s <port> <command>\n" % sys.argv[0])
    f.write("  <port> is the PCI ID of the downstream port.\n")
    f.write("  <command> is one of:")
    for name in sorted(globals()):
        if name.startswith("cmd_"):
            f.write(" " + name[4:])
    f.write("\n")

def resolve_dev(dev):
    """Expand a possibly abbreviated PCI device ID to its full sysfs name.

    If *dev* already names an entry under ``_dev_dir`` it is returned
    unchanged.  Otherwise it is parsed as ``[[domain:]bus:]device[.func]``
    with every field in hexadecimal; omitted fields default to 0 and the
    result is formatted as ``dddd:bb:dd.f``.

    On a parse failure, or if the resulting device does not exist under
    ``_dev_dir``, an error is written to stderr and the process exits with
    status 1 (``SystemExit``).
    """
    if os.path.exists(_dev_dir+dev):
        return dev

    m = _dev_re.match(dev)
    if not m:
        sys.stderr.write("Could not parse device ID '%s'.\n" % dev)
        sys.exit(1)

    dom, bus, slot, func = m.groups()
    try:
        # Group 1 carries a trailing ':' and group 4 a leading '.'.
        id_dom = int(dom[:-1], 16) if dom else 0
        id_bus = int(bus, 16) if bus else 0
        id_dev = int(slot, 16)
        id_func = int(func[1:], 16) if func else 0
    except ValueError:
        # The regex accepts any word characters, so non-hex input such as
        # "zz" reaches int(); report it like any other unparsable ID instead
        # of dying with a traceback.
        sys.stderr.write("Could not parse device ID '%s'.\n" % dev)
        sys.exit(1)

    dev = ( "%04x:%02x:%02x.%x" %
            (id_dom, id_bus, id_dev, id_func) )

    if not os.path.exists(_dev_dir+dev):
        sys.stderr.write("Could not find device '%s'.\n" % dev)
        sys.exit(1)

    return dev

def get_state(dev):
    """Read the 16-bit word at ``CAP_EXP+0x10`` of *dev* using ``setpci``.

    Returns the value as an int.  Raises ``subprocess.CalledProcessError``
    if ``setpci`` exits with a non-zero status, and ``ValueError`` if its
    output is not a hexadecimal number.
    """
    args = ['setpci', '-s', dev, 'CAP_EXP+0x10.w']
    # Ask for text output: under Python 3 check_output() otherwise returns
    # bytes, and stripping bytes with a str argument raises TypeError.
    out = subprocess.check_output(args, universal_newlines=True)
    return int(out.strip(), 16)

def set_state(dev, val):
    """Write *val* to the 16-bit word at ``CAP_EXP+0x10`` of *dev*.

    The write is performed by running ``setpci``; a non-zero exit status
    raises ``subprocess.CalledProcessError``.
    """
    assignment = 'CAP_EXP+0x10.w=%x' % val
    subprocess.check_call(['setpci', '-s', dev, assignment])

def cmd_link_up(dev):
    """Clear bit 0x10 in the port's ``CAP_EXP+0x10`` word, then pause.

    NOTE(review): bit 0x10 at this offset is presumably the Link Disable
    bit of the PCIe Link Control register -- confirm against the spec.
    """
    set_state(dev, get_state(dev) & ~0x10)
    # Short fixed delay after the register write before returning.
    time.sleep(0.1)

def cmd_sys_up(dev):
    """Trigger a rescan by writing ``1`` to the device's sysfs ``rescan`` file.

    Raises ``IOError``/``OSError`` if the file cannot be opened or written.
    """
    # Use a context manager so the handle is closed even if the write fails.
    with open(_dev_dir + dev + '/rescan', 'w') as f:
        f.write("1\n")

def cmd_up(dev):
    """Bring the port fully up: enable the link, then rescan below it."""
    for step in (cmd_link_up, cmd_sys_up):
        step(dev)

def cmd_sys_down(dev):
    """Remove every PCI device listed in the port's sysfs directory.

    Each directory entry whose name looks like a full PCI address
    (``domain:bus:device.function``) gets ``1`` written to its ``remove``
    file.  Afterwards, any such entry that is still present is given a
    single 0.1 s grace period; this is a one-shot delay, not a poll, so the
    entry may still exist on return.

    Raises ``IOError``/``OSError`` if the directory cannot be listed or a
    ``remove`` file cannot be opened or written.
    """
    port_dir = _dev_dir + dev
    removed = []
    for entry in os.listdir(port_dir):
        if not _full_dev_re.match(entry):
            continue
        removed.append(entry)
        # Context manager guarantees the handle is closed even if the write
        # to sysfs fails.
        with open('%s/%s/remove' % (port_dir, entry), 'w') as f:
            f.write("1\n")

    for entry in removed:
        if os.path.exists('%s/%s' % (port_dir, entry)):
            time.sleep(0.1)

def cmd_link_down(dev):
    """Set bit 0x10 in the port's ``CAP_EXP+0x10`` word.

    NOTE(review): bit 0x10 at this offset is presumably the Link Disable
    bit of the PCIe Link Control register -- confirm against the spec.
    """
    set_state(dev, get_state(dev) | 0x10)

def cmd_down(dev):
    """Take the port fully down: remove devices below it, then disable the link."""
    for step in (cmd_sys_down, cmd_link_down):
        step(dev)

def main():
    """Parse ``<port> <command>`` from ``sys.argv`` and run the command.

    Returns 0 on success, or 1 after printing usage when the argument count
    is wrong or the command is unknown.  ``resolve_dev`` may terminate the
    process itself if the port cannot be parsed or found.
    """
    if len(sys.argv) != 3:
        usage(sys.stderr)
        return 1

    # Resolve the port first so a bad device is reported before a bad command,
    # matching the original evaluation order.
    dev = resolve_dev(sys.argv[1])

    # Commands are dispatched by name to the module-level cmd_* functions.
    cmd_func = globals().get("cmd_" + sys.argv[2])
    if cmd_func is None:
        usage(sys.stderr)
        return 1

    # Function-call form with a single argument prints the same text on
    # Python 2 and 3; the old print statement is a SyntaxError on Python 3.
    print("Device: %s" % dev)

    cmd_func(dev)

    return 0

# Run main() only when executed as a script, and use its return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
